package com.hyj.spark.offline

import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{HashingTF, IDF, StringIndexer}
import org.apache.spark.sql.SparkSession
/** Predict which category an article belongs to:
  * 1. From the training set (each article has one known category), iterate over
  *    the articles to count the total number of articles and the number of
  *    articles per category; then iterate over each article's words to count
  *    per-category word frequencies.
  * 2. Compute each category's prior = (articles in that category / total
  *    articles) and each word's per-category proportion — together these form
  *    the trained model.
  * 3. Given a test set (articles whose true categories are known), apply the
  *    model: for each article, for each word, for each category, sum the
  *    per-word scores plus the category prior, and choose the highest-scoring
  *    category as the prediction.
  * 4. Accuracy = correctly predicted articles / total test articles.
  */
object TextPredictBayesHyj {

  /**
   * Trains a multinomial Naive Bayes text classifier on the `badou.news_seg`
   * Hive table and prints accuracy on a held-out test split.
   *
   * Pipeline: raw text -> HashingTF term counts -> (IDF, computed but unused
   * downstream — see note) -> StringIndexer on the label -> 80/20 split ->
   * NaiveBayes -> multiclass accuracy.
   *
   * @param args unused command-line arguments
   */
  def main(args: Array[String]): Unit = {
    System.setProperty("HADOOP_USER_NAME", "root")

    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("User Base")
      .enableHiveSupport()
      .getOrCreate()

    // Silence Spark's INFO/WARN chatter so only our output and errors show.
    Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)

    try {
      // Each row of news_seg is "<space-separated words>##@@##<label>":
      // split off the label, then tokenize the text into a word array.
      val df = spark.sql("select * from badou.news_seg")
        .selectExpr("split(sentence,'##@@##')[0] as sentence", "split(sentence,'##@@##')[1] as label")
        .selectExpr("split(sentence,' ') as sentence", "label")

      // Term-frequency features via the hashing trick.
      // setBinary(false): multinomial counts (a word seen 10 times -> value 10).
      // setBinary(true) would give Bernoulli presence/absence (0/1) instead.
      val coltf = "feature_tf"
      val hashingTF = new HashingTF()
        .setBinary(false)
        .setInputCol("sentence")
        .setOutputCol(coltf)
        .setNumFeatures(1 << 18) // hashed vocabulary / feature-vector size: 262144

      // feature_tf is a SparseVector: (size, [term indices], [term counts]),
      // e.g. [0,1,0,2,0,0,0,0,0,3] is stored as (10,[1,3,9],[1,2,3]).
      val df_tf = hashingTF.transform(df).select(coltf, "label")

      // IDF(term) = log(total docs / docs containing term), skipping terms that
      // appear in fewer than 2 documents (setMinDocFreq).
      // NOTE(review): df_tfIdf is computed but never consumed below — training
      // intentionally uses the raw TF counts, which is what the multinomial
      // Naive Bayes model expects. Kept only for inspection/experimentation.
      val tdIdfCol = "feature_tfidf"
      val idf = new IDF()
        .setInputCol(coltf)
        .setOutputCol(tdIdfCol)
        .setMinDocFreq(2)
      val idfModel = idf.fit(df_tf)
      val df_tfIdf = idfModel.transform(df_tf).select(tdIdfCol, "label")

      // Encode the string label as a numeric index (required by NaiveBayes).
      // handleInvalid = "error": fail fast on labels unseen at fit time
      // (the alternative, "skip", silently drops such rows).
      val labelCol = "indexed"
      val indexer = new StringIndexer()
        .setInputCol("label")
        .setOutputCol(labelCol)
        .setHandleInvalid("error")
      val df_tf_idf_index = indexer.fit(df_tf).transform(df_tf)

      // 80/20 train/test split; a fixed seed makes the reported accuracy
      // reproducible across runs (the original split was nondeterministic).
      val Array(train, test) = df_tf_idf_index.randomSplit(Array(0.8, 0.2), seed = 42L)

      // Multinomial Naive Bayes over word counts with Laplace smoothing = 1.
      // ("bernoulli" would instead model word presence/absence as 0/1.)
      val predCol = "predict"
      val nb = new NaiveBayes()
        .setModelType("multinomial")
        .setSmoothing(1.0)
        .setFeaturesCol(coltf)
        .setLabelCol(labelCol)
        .setPredictionCol(predCol)
        .setProbabilityCol("prob")
        .setRawPredictionCol("raw") // per-class log-scores

      // Fit the model on the training split.
      val model = nb.fit(train)

      // Score the test split; each row gains raw per-class scores ("raw"),
      // class probabilities ("prob") and the predicted class index ("predict").
      val pred = model.transform(test)

      // accuracy = correctly predicted test articles / total test articles
      val eval = new MulticlassClassificationEvaluator()
        .setLabelCol(labelCol)
        .setPredictionCol(predCol)
        .setMetricName("accuracy")
      val accuracy = eval.evaluate(pred)

      println(s"accuracy:$accuracy")
    } finally {
      // Release local Spark resources even if the job fails — the original
      // never stopped the session.
      spark.stop()
    }
  }
}
